{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "37dcaedc",
   "metadata": {},
   "source": [
    "# Neural Spline Flow on a Circular and a Normal Coordinate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3f13738",
   "metadata": {},
   "source": [
    "We aim to approximate a distribution having as circular and a normal coordinate. To construct such a case, let $x$ be the normal (unbound) coordinate follow a standard normal distribution, i.e. \n",
    "$$ p(x) = \\frac{1}{\\sqrt{2\\pi}} e^{-\\frac{1}{2} x ^ 2}.$$\n",
    "The circular random variable $\\phi$ follows a [von Mises distribution](https://en.wikipedia.org/wiki/Von_Mises_distribution) given by\n",
    "$$ p(\\phi|x) = \\frac{1}{2\\pi I_0(1)} e^{\\cos(\\phi-\\mu(x))}, $$\n",
    "where $I_0$ is the $0^\\text{th}$ order Bessel function of the first kind and we set $\\mu(x) = 3x$. Hence, our full target is given by\n",
    "$$ p(x, \\phi) = p(x)p(\\phi|x) = \\frac{1}{(2\\pi)^{\\frac{3}{2}} I_0(1)} e^{-\\frac{1}{2} x ^ 2 + \\cos(\\phi-3x)}. $$\n",
    "We use a neural spline flow that models the two coordinates accordingly."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "caf7dcb0",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ba0ffb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import packages\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import normflows as nf\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from matplotlib import cm\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd7d6b7f",
   "metadata": {},
   "source": [
    "This is our target $p(x, \\phi)$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15008e69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up target\n",
    "class GaussianVonMises(nf.distributions.Target):\n",
    "    def __init__(self):\n",
    "        super().__init__(prop_scale=torch.tensor(2 * np.pi), \n",
    "                         prop_shift=torch.tensor(-np.pi))\n",
    "        self.n_dims = 2\n",
    "        self.max_log_prob = -1.99\n",
    "        self.log_const = -1.5 * np.log(2 * np.pi) - np.log(np.i0(1))\n",
    "    \n",
    "    def log_prob(self, x):\n",
    "        return -0.5 * x[:, 0] ** 2 + torch.cos(x[:, 1] - 3 * x[:, 0]) + self.log_const"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c113dfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "target = GaussianVonMises()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc02e62b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot target\n",
    "grid_size = 300\n",
    "xx, yy = torch.meshgrid(torch.linspace(-2.5, 2.5, grid_size), torch.linspace(-np.pi, np.pi, grid_size))\n",
    "zz = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)\n",
    "\n",
    "log_prob = target.log_prob(zz).view(*xx.shape)\n",
    "prob = torch.exp(log_prob)\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "plt.figure(figsize=(15, 15))\n",
    "plt.pcolormesh(yy, xx, prob.data.numpy(), cmap='coolwarm')\n",
    "plt.gca().set_aspect('equal', 'box')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab4234be",
   "metadata": {},
   "outputs": [],
   "source": [
    "base = nf.distributions.UniformGaussian(2, [1], torch.tensor([1., 2 * np.pi]))\n",
    "\n",
    "K = 12\n",
    "\n",
    "flow_layers = []\n",
    "for i in range(K):\n",
    "    flow_layers += [nf.flows.CircularAutoregressiveRationalQuadraticSpline(2, 1, 512, [1], num_bins=10,\n",
    "                                                                           tail_bound=torch.tensor([5., np.pi]),\n",
    "                                                                           permute_mask=True)]\n",
    "\n",
    "model = nf.NormalizingFlow(base, flow_layers, target)\n",
    "\n",
    "# Move model on GPU if available\n",
    "enable_cuda = True\n",
    "device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be7a29e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot model\n",
    "log_prob = model.log_prob(zz.to(device)).to('cpu').view(*xx.shape)\n",
    "prob = torch.exp(log_prob)\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "plt.figure(figsize=(15, 15))\n",
    "plt.pcolormesh(yy, xx, prob.data.numpy(), cmap='coolwarm')\n",
    "plt.gca().set_aspect('equal', 'box')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92007234",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34efecc6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Train model\n",
    "max_iter = 10000\n",
    "num_samples = 2 ** 14\n",
    "show_iter = 2500\n",
    "\n",
    "\n",
    "loss_hist = np.array([])\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter)\n",
    "\n",
    "for it in tqdm(range(max_iter)):\n",
    "    optimizer.zero_grad()\n",
    "    \n",
    "    # Compute loss\n",
    "    loss = model.reverse_kld(num_samples)\n",
    "    \n",
    "    # Do backprop and optimizer step\n",
    "    if ~(torch.isnan(loss) | torch.isinf(loss)):\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "    # Log loss\n",
    "    loss_hist = np.append(loss_hist, loss.to('cpu').data.numpy())\n",
    "    \n",
    "    # Plot learned model\n",
    "    if (it + 1) % show_iter == 0:\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            log_prob = model.log_prob(zz.to(device)).to('cpu').view(*xx.shape)\n",
    "        model.train()\n",
    "        prob = torch.exp(log_prob)\n",
    "        prob[torch.isnan(prob)] = 0\n",
    "\n",
    "        plt.figure(figsize=(15, 15))\n",
    "        plt.pcolormesh(yy, xx, prob.data.numpy(), cmap='coolwarm')\n",
    "        plt.gca().set_aspect('equal', 'box')\n",
    "        plt.show()\n",
    "    \n",
    "    # Iterate scheduler\n",
    "    scheduler.step()\n",
    "\n",
    "# Plot loss\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.plot(loss_hist, label='loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a787f2f",
   "metadata": {},
   "source": [
    "## Visualization of the Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a359e9c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2D plot\n",
    "f, ax = plt.subplots(1, 2, sharey=True, figsize=(15, 7))\n",
    "\n",
    "log_prob = target.log_prob(zz).view(*xx.shape)\n",
    "prob = torch.exp(log_prob)\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "ax[0].pcolormesh(yy, xx, prob.data.numpy(), cmap='coolwarm')\n",
    "ax[0].set_aspect('equal', 'box')\n",
    "\n",
    "ax[0].set_xticks(ticks=[-np.pi, -np.pi/2, 0, np.pi/2, np.pi])\n",
    "ax[0].set_xticklabels(['$-\\pi$', r'$-\\frac{\\pi}{2}$', '$0$', r'$\\frac{\\pi}{2}$', '$\\pi$'],\n",
    "                      fontsize=20)\n",
    "ax[0].set_yticks(ticks=[-2, -1, 0, 1, 2])\n",
    "ax[0].set_yticklabels(['$-2$', '$-1$', '$0$', '$1$', '$2$'],\n",
    "                      fontsize=20)\n",
    "ax[0].set_xlabel('$\\phi$', fontsize=24)\n",
    "ax[0].set_ylabel('$x$', fontsize=24)\n",
    "\n",
    "ax[0].set_title('Target', fontsize=24)\n",
    "\n",
    "log_prob = model.log_prob(zz.to(device)).to('cpu').view(*xx.shape)\n",
    "prob = torch.exp(log_prob)\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "ax[1].pcolormesh(yy, xx, prob.data.numpy(), cmap='coolwarm')\n",
    "ax[1].set_aspect('equal', 'box')\n",
    "\n",
    "ax[1].set_xticks(ticks=[-np.pi, -np.pi/2, 0, np.pi/2, np.pi])\n",
    "ax[1].set_xticklabels(['$-\\pi$', r'$-\\frac{\\pi}{2}$', '$0$', r'$\\frac{\\pi}{2}$', '$\\pi$'],\n",
    "                      fontsize=20)\n",
    "ax[1].set_xlabel('$\\phi$', fontsize=24)\n",
    "\n",
    "ax[1].set_title('Neural Spline Flow', fontsize=24)\n",
    "\n",
    "plt.subplots_adjust(wspace=0.1)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff90c95a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3D plot\n",
    "fig = plt.figure(figsize=(15, 7))\n",
    "ax1 = fig.add_subplot(1, 2, 1, projection='3d')\n",
    "ax2 = fig.add_subplot(1, 2, 2, projection='3d')\n",
    "\n",
    "phi = np.linspace(-np.pi, np.pi, grid_size)\n",
    "z = np.linspace(-2.5, 2.5, grid_size)\n",
    "\n",
    "# create the surface\n",
    "x = np.outer(np.ones(grid_size), np.cos(phi))\n",
    "y = np.outer(np.ones(grid_size), np.sin(phi))\n",
    "z = np.outer(z, np.ones(grid_size))\n",
    "\n",
    "# Target\n",
    "log_prob = target.log_prob(zz).view(*xx.shape)\n",
    "prob = torch.exp(log_prob)\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "prob_vis = prob / torch.max(prob)\n",
    "myheatmap = prob_vis.data.numpy()\n",
    "\n",
    "ax1._axis3don = False\n",
    "ax1.plot_surface(x, y, z, cstride=1, rstride=1, facecolors=cm.coolwarm(myheatmap), shade=False)\n",
    "\n",
    "ax1.set_title('Target', fontsize=24, y=0.97, pad=0)\n",
    "\n",
    "# Model\n",
    "log_prob = model.log_prob(zz.to(device)).to('cpu').view(*xx.shape)\n",
    "prob = torch.exp(log_prob)\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "prob_vis = prob / torch.max(prob)\n",
    "myheatmap = prob_vis.data.numpy()\n",
    "\n",
    "ax2._axis3don = False\n",
    "ax2.plot_surface(x, y, z, cstride=1, rstride=1, facecolors=cm.coolwarm(myheatmap), shade=False)\n",
    "\n",
    "t = ax2.set_title('Neural Spline Flow', fontsize=24, y=0.97, pad=0)\n",
    "\n",
    "plt.subplots_adjust(wspace=-0.4)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
